/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "ascir/Dialect/Asc/Transforms/Passes.h"
#include "InitFuncDef.h"

#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/SourceMgr.h"

#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // automatic casts between containers and python types

#define DEFINE_ADD_PASS(NAME, CONSTRUCTOR)                                     \
    m.def(NAME, [](PassManager &pm) { pm.addPass(CONSTRUCTOR()); })

#define DEFINE_ADD_PASS_ON(NEST, NAME, CONSTRUCTOR)                            \
    m.def(NAME, [](PassManager &pm) { pm.addNestedPass<NEST>(CONSTRUCTOR()); })

namespace py = pybind11;
using namespace mlir;

namespace {

void definePassManager(py::module &m)
{
    using namespace pybind11::literals;

    py::class_<PassManager>(m, "PassManager", py::module_local())
        .def(py::init<MLIRContext *>())
        .def("get_pipeline_str",
            [](PassManager &self) -> std::string {
              std::string result;
              llvm::raw_string_ostream os(result);
              self.printAsTextualPipeline(os);
              os.flush();
              return result;
            })
        .def("run",
            [](PassManager &self, ModuleOp &mod) {
              llvm::SourceMgr sourceMgr;
              SourceMgrDiagnosticHandler handler(sourceMgr, self.getContext());
              if (self.run(mod.getOperation()).failed())
                throw std::runtime_error("Failed to run passes");
            })
        .def("enable_verifier",
            [](PassManager &self, bool enable) { self.enableVerifier(enable); },
            "enable"_a = true)
        .def("enable_printing", [](PassManager &self) {
          OpPrintingFlags flags;
          flags.enableDebugInfo(true);
          self.enableIRPrinting([](Pass *, Operation *) { return true; }, /*shouldPrintBeforePass*/
              [](Pass *, Operation *) { return true; }, /*shouldPrintAfterPass*/
              false, /*printModuleScope*/
              false, /*printAfterOnlyOnChange*/
              true, /*printAfterOnlyOnFailure*/
              llvm::errs(), /*out*/
              flags /*opPrintingFlags*/
              );
        });
}

void defineCommonPasses(py::module &mod)
{
    auto m = mod.def_submodule("common");
    DEFINE_ADD_PASS("add_canonicalizer", createCanonicalizerPass);
    DEFINE_ADD_PASS("add_cse", createCSEPass);
    DEFINE_ADD_PASS("add_inliner", createInlinerPass);
    DEFINE_ADD_PASS("add_licm", createLoopInvariantCodeMotionPass);
    DEFINE_ADD_PASS("add_print_ir", createPrintIRPass);
    DEFINE_ADD_PASS("add_reconcile_unrealized_casts",
                    createReconcileUnrealizedCastsPass);
    DEFINE_ADD_PASS("add_sccp", createSCCPPass);
    DEFINE_ADD_PASS("add_strip_debug_info", createStripDebugInfoPass);
    DEFINE_ADD_PASS("add_symbol_dce", createSymbolDCEPass);
}

void defineAscendCPasses(py::module &mod)
{
    using namespace ascendc;
    auto m = mod.def_submodule("ascendc");
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_noop_pass", createNoopPass);
    DEFINE_ADD_PASS("add_detect_kernel_type", createDetectKernelTypePass);
    DEFINE_ADD_PASS("add_declare_py_struct", createDeclarePyStructPass);
    DEFINE_ADD_PASS("add_define_cube_only", createDefineCubeOnlyPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_erase_sync", createEraseSyncPass);
    DEFINE_ADD_PASS("add_generate_boilerplate", createGenerateBoilerplatePass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_hoist_que_bind", createHoistQueBindPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_hoist_ub_allocation", createHoistUBAllocationPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_input_output_tensor", createInputOutputTensorPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_insert_sync", createInsertSyncPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_materialize_tensor", createMaterializeTensorPass);
    DEFINE_ADD_PASS("add_legalize_kernel_args", createLegalizeKernelArgsPass);
    DEFINE_ADD_PASS("add_privatize_func", createPrivatizeFuncPass);
    DEFINE_ADD_PASS("add_detect_enable_debug", createDetectEnableDebugPass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_unify_pipe", createUnifyPipePass);
    DEFINE_ADD_PASS_ON(func::FuncOp, "add_verify_sync", createVerifySyncPass);
}

} // namespace

namespace pybind11 {
namespace asc {
void pyasc_init_passes(py::module &&m)
{
    definePassManager(m);
    defineCommonPasses(m);
    defineAscendCPasses(m);
}
}   // namespace asc
}   // namespace pybind11

#undef DEFINE_ADD_PASS
#undef DEFINE_ADD_PASS_ON
